from typing import List, Optional, Tuple, Union
import numpy as np

from gptcache.utils import import_weaviate
from gptcache.utils.log import gptcache_log
from gptcache.manager.vector_data.base import VectorBase, VectorData

import_weaviate()

from weaviate import Client
from weaviate.auth import AuthCredentials
from weaviate.config import Config
from weaviate.embedded import EmbeddedOptions
from weaviate.types import NUMBERS


class Weaviate(VectorBase):
    """Vector store backed by Weaviate.

    Every vector is stored as an object of one Weaviate class (collection)
    that carries a single integer property, ``data_id``, holding the id
    GPTCache assigned to the vector.

    The connection arguments (``url``, ``auth_client_secret``,
    ``timeout_config``, ``proxies``, ``trust_env``, ``additional_headers``,
    ``startup_period``, ``embedded_options`` and ``additional_config``) are
    forwarded unchanged to ``weaviate.Client``.  When both ``url`` and
    ``embedded_options`` are ``None``, a default ``EmbeddedOptions()`` is
    passed to the client instead.

    :param class_name: name of the Weaviate class to use when no
        ``class_schema`` is supplied, defaults to ``"GPTCache"``.
    :type class_name: str
    :param class_schema: full Weaviate class schema; when given, its
        ``"class"`` entry is used as the class name and ``class_name`` is
        ignored. When omitted, a default schema is generated.
    :type class_schema: Optional[dict]
    :param top_k: number of results returned by :meth:`search` when the
        caller does not pass its own ``top_k``, defaults to 1.
    :type top_k: Optional[int]
    """

    TIMEOUT_TYPE = Union[Tuple[NUMBERS, NUMBERS], NUMBERS]

    def __init__(
        self,
        url: Optional[str] = None,
        auth_client_secret: Optional[AuthCredentials] = None,
        timeout_config: TIMEOUT_TYPE = (10, 60),
        proxies: Union[dict, str, None] = None,
        trust_env: bool = False,
        additional_headers: Optional[dict] = None,
        startup_period: Optional[int] = 5,
        embedded_options: Optional[EmbeddedOptions] = None,
        additional_config: Optional[Config] = None,
        class_name: str = "GPTCache",
        class_schema: Optional[dict] = None,
        top_k: Optional[int] = 1,
    ) -> None:
        # Neither a remote URL nor explicit embedded options were given:
        # fall back to the default embedded options.
        if url is None and embedded_options is None:
            embedded_options = EmbeddedOptions()

        self.client = Client(
            url=url,
            auth_client_secret=auth_client_secret,
            timeout_config=timeout_config,
            proxies=proxies,
            trust_env=trust_env,
            additional_headers=additional_headers,
            startup_period=startup_period,
            embedded_options=embedded_options,
            additional_config=additional_config,
        )

        if class_schema:
            # A user-supplied schema dictates the class name.
            self.class_schema = class_schema
            self.class_name = class_schema.get("class")
        else:
            # class_name must be set first: the default schema reads it.
            self.class_name = class_name
            self.class_schema = self._get_default_class_schema()

        self._create_class()
        self.top_k = top_k

    def _create_class(self):
        """Create the Weaviate class unless it already exists.

        :return: the name of the class in use.
        """
        if self.client.schema.exists(self.class_name):
            gptcache_log.warning(
                "The %s collection already exists, and it will be used directly.",
                self.class_name,
            )
        else:
            self.client.schema.create_class(self.class_schema)
        return self.class_name

    def _get_default_class_schema(self) -> dict:
        """Build the default schema: one ``data_id`` int property and a
        vector index configured with cosine distance.

        :return: the schema dictionary for ``self.class_name``.
        """
        return {
            "class": self.class_name,
            "description": "LLM response cache",
            "properties": [
                {
                    "name": "data_id",
                    "dataType": ["int"],
                    "description": "The data-id generated by GPTCache for vectors.",
                }
            ],
            "vectorIndexConfig": {"distance": "cosine"},
        }

    def mul_add(self, datas: List[VectorData]):
        """Add several vectors through the client's batch interface.

        :param datas: items whose ``id`` is stored as ``data_id`` and whose
            ``data`` is stored as the object's vector.
        """
        with self.client.batch(batch_size=100, dynamic=True) as batch:
            for data in datas:
                batch.add_data_object(
                    data_object={"data_id": data.id},
                    class_name=self.class_name,
                    vector=data.data,
                )

    def search(self, data: np.ndarray, top_k: int = -1):
        """Run a near-vector query against the class.

        :param data: the query vector.
        :param top_k: maximum number of results; ``-1`` means "use the
            ``top_k`` configured on this instance".
        :return: a list of ``(distance, data_id)`` tuples, in the order the
            query returned them.
        """
        if top_k == -1:
            top_k = self.top_k

        result = (
            self.client.query.get(class_name=self.class_name, properties=["data_id"])
            .with_near_vector(content={"vector": data})
            .with_additional(["distance"])
            .with_limit(top_k)
            .do()
        )

        hits = result["data"]["Get"][self.class_name]
        return [(hit["_additional"]["distance"], hit["data_id"]) for hit in hits]

    def _get_uuids(self, data_ids):
        """Resolve GPTCache data ids to Weaviate object uuids.

        Ids that match no stored object are skipped (with a warning) instead
        of raising, so the returned list may be shorter than ``data_ids``.

        :param data_ids: iterable of ``data_id`` values to look up.
        :return: list of uuids of the first matching object for each id
            that was found.
        """
        uuid_list = []

        for data_id in data_ids:
            res = (
                self.client.query.get(
                    class_name=self.class_name, properties=["data_id"]
                )
                .with_where(
                    {"path": ["data_id"], "operator": "Equal", "valueInt": data_id}
                )
                .with_additional(["id"])
                # Only the first match is used, so do not fetch more.
                .with_limit(1)
                .do()
            )

            matches = res["data"]["Get"][self.class_name]
            if not matches:
                gptcache_log.warning(
                    "No object with data_id %s was found in the %s collection; "
                    "skipping it.",
                    data_id,
                    self.class_name,
                )
                continue

            uuid_list.append(matches[0]["_additional"]["id"])

        return uuid_list

    def delete(self, ids):
        """Delete the objects stored for the given data ids.

        Ids that are not present in the store are ignored.

        :param ids: iterable of ``data_id`` values to delete.
        """
        for uuid in self._get_uuids(ids):
            self.client.data_object.delete(class_name=self.class_name, uuid=uuid)

    def rebuild(self, ids=None):
        """Do nothing; present to satisfy the vector-store interface."""
        return

    def flush(self):
        """Flush the client's pending batch."""
        self.client.batch.flush()

    def close(self):
        """Flush the pending batch; no other teardown is performed here."""
        self.flush()

    def get_embeddings(self, data_id: int):
        """Fetch the stored vector for one data id.

        :param data_id: the ``data_id`` to look up.
        :return: the vector as a ``float32`` numpy array, or ``None`` when
            no object with that ``data_id`` exists.
        """
        results = (
            self.client.query.get(class_name=self.class_name, properties=["data_id"])
            .with_where(
                {
                    "path": ["data_id"],
                    "operator": "Equal",
                    "valueInt": data_id,
                }
            )
            .with_additional(["vector"])
            .with_limit(1)
            .do()
        )

        results = results["data"]["Get"][self.class_name]

        if not results:
            return None

        return np.asarray(results[0]["_additional"]["vector"], dtype="float32")

    def update_embeddings(self, data_id: int, emb: np.ndarray):
        """Replace the vector stored for ``data_id``.

        The existing object (if any) is deleted and a new one is created
        with the same ``data_id`` and the new vector.

        :param data_id: the ``data_id`` whose vector is replaced.
        :param emb: the new vector.
        """
        self.delete([data_id])

        self.client.data_object.create(
            data_object={"data_id": data_id},
            class_name=self.class_name,
            vector=emb,
        )
